"""
This module generates datasets of points sampled from a sphere (or ball) and
embedded in a higher-dimensional space via a linear projection matrix.
It provides functionality for sampling, projection/distance calculation, and dataset creation.
"""

from pathlib import Path
from typing import Tuple

import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from scipy.stats import special_ortho_group
from torch.utils.data import Dataset

from diffusion_bandit import utils

# Custom OmegaConf resolvers referenced from the config files.
# NOTE: the "eval" resolver hands config text to Python's built-in eval, so
# config files are effectively executable code and must be trusted.
OmegaConf.register_new_resolver(name="eval", resolver=eval, use_cache=True)
# "generate_random_seed" produces a seed that ends up recorded in the
# experiment's config (use_cache=True so repeated lookups reuse the value).
OmegaConf.register_new_resolver(
    name="generate_random_seed",
    resolver=utils.seeding.generate_random_seed,
    use_cache=True,
)


@hydra.main(version_base=None, config_path="configs", config_name="shape_dataset")
def main(config: DictConfig) -> None:
    """
    Generate a sphere dataset embedded in R^{d_ext} and save it to disk.

    Builds a (d_ext, d_int) projection matrix according to
    ``config.projection.type``, samples points with ``generate_shape_dataset``
    (forwarding every key of ``config.dataset`` as a keyword argument), and
    writes ``<config.data_dir>/<config.save.name>.pt`` containing the points,
    the projection matrix and the fully resolved config.

    Args:
        config: Hydra config. Reads ``projection.type``, ``dataset.d_ext``,
            ``dataset.d_int`` (plus the rest of ``dataset``), ``data_dir`` and
            ``save.name``.

    Raises:
        ValueError: If ``projection.type`` is neither "ortho" nor "gaussian",
            or if an "ortho" projection is requested with ``d_int > d_ext``.
    """
    print(OmegaConf.to_yaml(config))
    utils.seeding.seed_everything(config)

    d_ext = config.dataset.d_ext
    d_int = config.dataset.d_int

    # Set up projection
    if config.projection.type == "ortho":
        if d_int > d_ext:
            # Slicing [:, :d_int] would silently return only d_ext columns,
            # changing the intrinsic dimension without any error.
            raise ValueError(
                "projection.type='ortho' requires d_int <= d_ext, "
                f"got d_int={d_int}, d_ext={d_ext}"
            )
        # First d_int columns of a random rotation matrix: orthonormal columns.
        projector = torch.as_tensor(
            special_ortho_group.rvs(d_ext)[:, :d_int], dtype=torch.float32
        )
    elif config.projection.type == "gaussian":
        # i.i.d. N(0, 1/d_ext) entries, so each column has unit squared norm
        # in expectation (but columns are not exactly orthonormal).
        projector = torch.randn(d_ext, d_int) / d_ext**0.5
    else:
        raise ValueError(f"Unknown projection type: {config.projection.type}")

    # Generate dataset
    x_data = generate_shape_dataset(projector=projector, **config.dataset)

    # Create save directory
    save_dir = Path(config.data_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Save the sampled points, the projection matrix and the resolved config.
    dataset_path = save_dir / f"{config.save.name}.pt"
    torch.save(
        {
            "x_data": x_data,
            "projector": projector,
            "dataset_config": OmegaConf.to_container(config, resolve=True),
        },
        dataset_path,
    )

    print(f"Dataset saved to {dataset_path}")


def generate_shape_dataset(
    num_samples: int,
    projector: torch.Tensor,
    radius: float,
    surface: bool,
    **kwargs,  # pylint: disable=unused-argument
) -> torch.Tensor:
    """
    Sample points from a sphere (or ball) and embed them in the extrinsic space.

    Points are drawn in the intrinsic space of dimension ``projector.shape[1]``
    with ``sample_from_sphere`` and mapped to the extrinsic space as
    ``x_int @ projector.T``.

    Args:
        num_samples (int): Number of samples to generate.
        projector (torch.Tensor): Projection matrix of shape (d_ext, d_int).
        radius (float): Radius of the sphere.
        surface (bool): If True, sample only on the surface; otherwise sample
            inside the ball.
        **kwargs: Ignored. Allows passing a whole config section (which may
            contain extra keys such as ``d_ext``/``d_int``) as keyword arguments.

    Returns:
        torch.Tensor: Generated points of shape (num_samples, d_ext).
    """
    x_int = sample_from_sphere(
        num_samples=num_samples,
        d_int=projector.shape[1],
        radius=radius,
        surface=surface,
    )

    # Embed: (num_samples, d_int) @ (d_int, d_ext) -> (num_samples, d_ext)
    return x_int @ projector.T


def sample_from_sphere(
    num_samples: int,
    d_int: int,
    radius: float = 1,
    surface: bool = True,
    *,
    uniform_volume: bool = False,
) -> torch.Tensor:
    """
    Sample points from a sphere or ball in R^{d_int}, centred at the origin.

    Directions are uniform on the unit sphere (normalized Gaussian vectors).
    With ``surface=False`` the radius is drawn as ``radius * U`` with
    ``U ~ Uniform(0, 1)`` by default. That is uniform in the *radius*, not in
    volume, so points concentrate towards the centre. Pass
    ``uniform_volume=True`` to sample uniformly over the ball's volume
    (``radius * U ** (1 / d_int)``).

    Args:
        num_samples (int): Number of samples to generate.
        d_int (int): Dimension of the ambient space of the sphere.
        radius (float): Radius of the sphere.
        surface (bool): If True, sample only from the surface.
        uniform_volume (bool): Only used when ``surface`` is False. If True,
            sample uniformly over the ball's volume instead of uniformly in radius.

    Returns:
        torch.Tensor: Sampled points of shape (num_samples, d_int).
    """
    # Radii are drawn before directions to keep the RNG call order (and thus
    # seeded outputs) identical to the historical behaviour.
    if surface:
        radii = torch.full((num_samples,), radius)
    else:
        u = torch.rand(num_samples)
        if uniform_volume:
            # Inverse CDF of the radial law of a uniform d-ball: P(R <= r) = r^d.
            u = u.pow(1.0 / d_int)
        radii = radius * u

    samples = torch.randn(num_samples, d_int)
    samples /= torch.norm(samples, dim=1, keepdim=True)
    samples *= radii.unsqueeze(1)
    return samples


def distance_to_sphere(
    x_data: torch.Tensor,
    projector: torch.Tensor,
    radius: float = 1,
    surface: bool = True,
) -> torch.Tensor:
    """
    Calculate the L2 distance from points in R^{d_ext} to a sphere embedded via a projection.

    The distance is measured to the point returned by ``project_onto_sphere``.
    When ``projector`` has orthonormal columns (e.g. the "ortho" projection),
    that point is the nearest point of the embedded sphere/ball and the
    distance is exact. For other projectors the returned point still lies on
    the embedded set, so the value is generally an upper bound on the true distance.

    Args:
        x_data (torch.Tensor): Input points of shape (n_samples, d_ext).
        projector (torch.Tensor): Projection matrix of shape (d_ext, d_int).
        radius (float): Radius of the sphere.
        surface (bool): If True, measure the distance to the surface only;
            otherwise to the solid ball.

    Returns:
        torch.Tensor: Distances of shape (n_samples,).
    """
    projections = project_onto_sphere(x_data, projector, radius, surface)
    return torch.linalg.norm(projections - x_data, dim=1)


def project_onto_sphere(
    x_data: torch.Tensor,
    projector: torch.Tensor,
    radius: float,
    surface: bool = True,
) -> torch.Tensor:
    """
    Project points in R^{d_ext} onto a sphere of radius `radius` embedded via `projector`.

    Points are mapped to intrinsic coordinates with ``x_data @ projector``,
    rescaled radially there, and mapped back with ``@ projector.T``. The output
    is therefore always of the form ``z @ projector.T``. Any component of
    ``x_data`` that is not captured by ``x_data @ projector`` is discarded.

    Args:
        x_data (torch.Tensor): Input points of shape (n_samples, d_ext).
        projector (torch.Tensor): Projection matrix of shape (d_ext, d_int).
        radius (float): Radius of the sphere.
        surface (bool):
            - If True, project onto the sphere's surface. Rows whose intrinsic
              coordinates are exactly zero (every surface point is equally
              close) are sent to ``radius * e_1``, the first intrinsic basis
              vector, so the result always lies on the sphere.
            - If False, project onto the solid ball. Rows whose intrinsic norm
              is already <= radius keep their intrinsic coordinates.

    Returns:
        torch.Tensor: Projected points of shape (n_samples, d_ext).
    """
    x_low = x_data @ projector  # Shape: (n_samples, d_int)
    norms = torch.linalg.norm(x_low, dim=1, keepdim=True)  # Shape: (n_samples, 1)

    if surface:
        is_zero = norms == 0
        # Divide by 1 instead of 0 on degenerate rows; they are replaced below.
        scaled_low = x_low * (radius / norms.masked_fill(is_zero, 1.0))
        # At the origin the direction is undefined; pick radius * e_1
        # deterministically so the projected point is still on the sphere.
        fallback = torch.zeros(x_low.shape[1], dtype=x_low.dtype, device=x_low.device)
        if fallback.numel() > 0:
            fallback[0] = radius
        projected_low = torch.where(is_zero, fallback, scaled_low)
    else:
        # Scale by min(1, radius / norm). Clamping the denominator avoids an
        # inf intermediate on zero-norm rows (those rows select 1 anyway).
        scaling_factors = torch.where(
            norms <= radius,
            torch.ones_like(norms),
            radius / torch.clamp(norms, min=radius),
        )
        projected_low = x_low * scaling_factors  # Shape: (n_samples, d_int)

    return projected_low @ projector.T  # Shape: (n_samples, d_ext)


# Torch Dataset
class ShapeDataset(Dataset):
    """
    Map-style PyTorch Dataset pairing input points with targets.

    Item ``i`` is the tuple ``(x_data[i], y_data[i])``. The dataset length is
    taken from ``x_data`` alone.

    Attributes:
        x_data (torch.Tensor): The input data points.
        y_data (torch.Tensor): The corresponding labels or target values.
    """

    def __init__(self, x_data, y_data):
        self.x_data, self.y_data = x_data, y_data

    def __len__(self):
        n_items = len(self.x_data)
        return n_items

    def __getitem__(self, idx):
        point = self.x_data[idx]
        target = self.y_data[idx]
        return point, target


class ShapeXDataset(Dataset):
    """
    A PyTorch Dataset over unlabeled shape data.

    Indexing returns ``x_data[idx]`` only; there are no targets.

    Attributes:
        x_data (torch.Tensor): The input data points.
    """

    def __init__(self, x_data):
        self.x_data = x_data

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        return self.x_data[idx]


if __name__ == "__main__":
    # `config` is supplied by the @hydra.main decorator, hence no arguments here.
    main()  # pylint: disable=no-value-for-parameter
